import os
import pickle

import cv2
import gradio as gr
import numpy as np

# Import the feature extraction function from feature_extractor.py
from feature_extractor import extract_features_from_image

# Global variables for the models, class names, and training log
models = {}  # This will be a dictionary with keys: 'svm', 'rf', 'combined'
class_names = []
training_log = ""

# Processing parameters used by classify_new_image.
FIXED_WIDTH = 1000          # every input is rescaled to this width (pixels)
PATCH_SIZE = (100, 100)     # (width, height) of the non-overlapping patches
MIN_EDGE_DENSITY = 0.0      # patches must have strictly more edge pixels than this...
MAX_EDGE_DENSITY = 0.5      # ...and strictly fewer than this fraction to be classified
OVERLAY_ALPHA = 0.4         # transparency factor of the class overlays
MIN_CLASS_PROBABILITY = 0.2  # classes below this overall probability get no overlay

# Overlay colors in BGR for each class (adjust as desired)
COLOR_MAP = {
    'wood': (0, 255, 255),    # Yellow (BGR format)
    'brick': (0, 0, 255),     # Red (BGR format)
    'stone': (128, 128, 128)  # Gray (BGR format)
}
DEFAULT_OVERLAY_COLOR = (0, 255, 0)  # Green, used for class names not in COLOR_MAP


# ---------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------
def load_model(model_filename: str) -> None:
    """Load the pickled classifiers and class names into the module globals.

    The pickle is expected to hold a dict with the keys ``'models'`` (a dict
    of classifiers, e.g. ``{'svm': ..., 'rf': ..., 'combined': ...}``) and
    ``'class_names'``. When the file does not exist a message is printed and
    the globals are left untouched.

    Args:
        model_filename: Path of the pickle file to load.
    """
    global models, class_names, training_log

    if not os.path.exists(model_filename):
        print(f"Model file {model_filename} not found. Please train the model first.")
        return

    print("Found existing model file. Loading...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load model files that come from a trusted source.
    with open(model_filename, "rb") as f:
        model_data = pickle.load(f)

    models = model_data['models']
    class_names = model_data['class_names']
    training_log += "Loaded model from disk.\n"
    print("Loaded models from disk.")


# ---------------------------------------------------------------------
# Gradio Classification Function with Model Selection
# ---------------------------------------------------------------------
def _resize_to_width(image, target_width):
    """Return *image* scaled to *target_width* pixels wide, keeping its aspect ratio."""
    height, width = image.shape[:2]
    # Clamp to 1: a very wide, flat image would otherwise round down to a
    # zero-pixel height, which cv2.resize rejects.
    new_height = max(1, int(target_width * (height / width)))
    return cv2.resize(image, (target_width, new_height))


def _blend_overlays(image_bgr, overlays, prob_dict):
    """Return a copy of *image_bgr* with a translucent colored box per overlay.

    *overlays* holds ``(x, y, w, h, label)`` tuples. Boxes whose label has an
    overall probability below MIN_CLASS_PROBABILITY in *prob_dict* are skipped.
    """
    painted = image_bgr.copy()
    for x, y, w, h, label in overlays:
        if prob_dict[label] < MIN_CLASS_PROBABILITY:
            continue
        color = COLOR_MAP.get(label.lower(), DEFAULT_OVERLAY_COLOR)
        # cv2.rectangle includes both corner points, so the far corner is
        # (x + w - 1, y + h - 1); using (x + w, y + h) would bleed one pixel
        # into the neighbouring patches.
        cv2.rectangle(painted, (x, y), (x + w - 1, y + h - 1), color, thickness=-1)
    # Blend the painted layer with the untouched image.
    return cv2.addWeighted(painted, OVERLAY_ALPHA, image_bgr, 1 - OVERLAY_ALPHA, 0)


def classify_new_image(input_image_path, model_choice):
    """Classify an image patch-by-patch and return an annotated copy.

    The image is loaded from disk, resized to a fixed width, and split into
    non-overlapping patches. Every patch whose Canny edge density lies
    strictly between MIN_EDGE_DENSITY and MAX_EDGE_DENSITY is classified with
    the selected model, and the per-patch probabilities are averaged. If no
    patch qualifies, the whole resized image is classified instead.
    Transparent overlays are drawn on the classified patches, except for
    patches whose predicted class has an overall probability below
    MIN_CLASS_PROBABILITY. The progress log and the probabilities are printed
    to the terminal.

    Args:
        input_image_path: File path of the image to classify.
        model_choice: One of the keys of the global ``models`` dict.

    Returns:
        A tuple ``(final_prediction, prob_dict, annotated_image_rgb)``:
            final_prediction (str): The final predicted class.
            prob_dict (dict): Class name -> normalized probability.
            annotated_image_rgb (numpy array): The resized image, in RGB,
                with the transparent overlays.

    Raises:
        ValueError: If ``model_choice`` is unknown, no image path was given,
            or the image could not be loaded.
    """
    # Collect log fragments in a list; joining once avoids quadratic string
    # concatenation when there are many patches.
    log_parts = [training_log, "\nStarting classification...\n"]

    if model_choice not in models:
        raise ValueError(
            f"Model choice '{model_choice}' not found. "
            f"Available choices: {list(models.keys())}"
        )
    classifier = models[model_choice]

    # No path is passed when the user submits without selecting an image.
    if not input_image_path:
        raise ValueError("Error: No image was provided.")

    # Load image using OpenCV from file path
    image = cv2.imread(input_image_path)
    if image is None:
        raise ValueError("Error: Could not load image from the provided file path.")

    # Resize the image to a fixed width (1000 px) while maintaining aspect ratio.
    # The image from cv2.imread is already in BGR format.
    image_bgr = _resize_to_width(image, FIXED_WIDTH)
    log_parts.append("Resized image to fixed width of 1000 pixels.\n")
    log_parts.append("Image loaded in BGR format.\n")

    # Preprocessing – Convert to grayscale, apply Gaussian blur, and compute edges
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (9, 9), 0)
    edges = cv2.Canny(blurred, threshold1=0, threshold2=100)
    log_parts.append("Computed edges using Canny edge detection.\n")

    # Patch extraction parameters
    patch_w, patch_h = PATCH_SIZE
    img_h, img_w = gray.shape
    pixels_per_patch = patch_w * patch_h

    valid_patch_count = 0
    summed_probabilities = None
    overlays_list = []  # (x, y, w, h, predicted_label) for each valid patch

    # Loop over non-overlapping patches
    for y in range(0, img_h - patch_h + 1, patch_h):
        for x in range(0, img_w - patch_w + 1, patch_w):
            patch_edges = edges[y:y + patch_h, x:x + patch_w]
            density = np.count_nonzero(patch_edges) / pixels_per_patch
            log_parts.append(f"Patch at ({x}, {y}) - edge density: {density:.3f}\n")

            # Skip patches with no edges at all or with too many of them.
            if not MIN_EDGE_DENSITY < density < MAX_EDGE_DENSITY:
                continue

            valid_patch_count += 1
            patch = image_bgr[y:y + patch_h, x:x + patch_w]
            features = extract_features_from_image(patch)
            feature_vector = features['combined_features'].reshape(1, -1)
            patch_probabilities = classifier.predict_proba(feature_vector)[0]
            predicted_label = class_names[int(np.argmax(patch_probabilities))]
            log_parts.append(
                f"Patch at ({x}, {y}) predicted: {predicted_label} "
                f"with probabilities {patch_probabilities}\n"
            )
            overlays_list.append((x, y, patch_w, patch_h, predicted_label))

            if summed_probabilities is None:
                # Accumulate into our own array so the in-place additions
                # below never modify the classifier's returned buffer.
                summed_probabilities = np.array(patch_probabilities, dtype=float)
            else:
                summed_probabilities += patch_probabilities

    # Fallback: if no valid patches are found, classify the whole image.
    if valid_patch_count == 0:
        log_parts.append(
            "No valid patches found. Falling back to whole image classification.\n"
        )
        features = extract_features_from_image(image_bgr)
        feature_vector = features['combined_features'].reshape(1, -1)
        summed_probabilities = np.array(
            classifier.predict_proba(feature_vector)[0], dtype=float
        )
        valid_patch_count = 1

    # Average the probabilities from all valid patches and normalize them
    averaged_probabilities = summed_probabilities / valid_patch_count
    total = np.sum(averaged_probabilities)
    # Guard the division so an all-zero vector yields zeros instead of NaNs.
    if total > 0:
        normalized_probabilities = averaged_probabilities / total
    else:
        normalized_probabilities = averaged_probabilities

    final_prediction = class_names[int(np.argmax(normalized_probabilities))]
    prob_dict = {
        cls: float(prob)
        for cls, prob in zip(class_names, normalized_probabilities)
    }

    log_parts.append("Classification completed.\n")
    print("".join(log_parts))
    print(prob_dict)

    # Create an annotated image with transparent overlays
    annotated_image = _blend_overlays(image_bgr, overlays_list, prob_dict)

    # Convert annotated image from BGR to RGB for Gradio display
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)

    return final_prediction, prob_dict, annotated_image_rgb


# ---------------------------------------------------------------------
# Gradio Interface Setup using file paths and model selection
# ---------------------------------------------------------------------
def main() -> None:
    """Load the models from disk and launch the Gradio interface."""
    model_filename = "./svm_rf_combined.pkl"  # Adjust filename as needed
    load_model(model_filename)

    # Create a dropdown for model selection; fall back to the expected keys
    # when no model file could be loaded.
    model_choices = list(models) if models else ['svm', 'rf', 'combined']

    iface = gr.Interface(
        fn=classify_new_image,
        inputs=[
            gr.Image(type="filepath", label="Input Image"),
            gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0])
        ],
        outputs=[
            gr.Label(label="Predicted Class"),
            gr.Label(label="Probabilities"),
            gr.Image(label="Annotated Image")
        ],
        title="Stone, Wood, Brick Classifier",
        description=("Upload an image and select a classifier model (svm, rf, combined) to classify it.\n\n"
                     "The image is processed by subdividing it into patches and aggregating the predictions. "
                     "Transparent overlays are drawn on detected objects. Progress logs are printed to the terminal.")
    )

    # NOTE: share=True publishes the app through a public Gradio link.
    iface.launch(share=True)


if __name__ == "__main__":
    main()