File size: 8,278 Bytes
2ae889c
 
 
 
 
 
 
 
 
600cada
 
2ae889c
 
 
 
 
 
 
600cada
2ae889c
600cada
2ae889c
 
600cada
2ae889c
 
600cada
2ae889c
 
 
600cada
 
 
 
2ae889c
600cada
abd2768
 
 
 
 
 
2ae889c
600cada
2ae889c
 
600cada
 
 
 
2ae889c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abd2768
2ae889c
 
 
 
 
 
 
 
 
 
 
abd2768
2ae889c
 
 
 
abd2768
 
 
 
 
2ae889c
 
 
 
 
 
 
 
 
 
 
 
 
 
abd2768
2ae889c
 
 
 
 
 
 
 
 
 
abd2768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ae889c
 
600cada
2ae889c
 
600cada
2ae889c
 
600cada
 
 
2ae889c
 
600cada
 
 
 
2ae889c
 
abd2768
 
2ae889c
 
600cada
2ae889c
abd2768
2ae889c
600cada
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import cv2
import numpy as np
import pickle
import gradio as gr

# Import the feature extraction function from feature_extractor.py
from feature_extractor import extract_features_from_image

# Global state shared between load_model() and classify_new_image().
models = {}  # Trained classifiers keyed by name; expected keys: 'svm', 'rf', 'combined'
class_names = []  # Class labels, index-aligned with each classifier's predict_proba output
training_log = ""  # Human-readable log text accumulated across load/classify calls

# ---------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------
def load_model(model_filename):
    """Populate the module-level models/class_names from a pickled file.

    Args:
        model_filename: Path to a pickle containing a dict with keys
            'models' (classifier dict) and 'class_names' (label list).

    Side effects:
        Overwrites the global ``models`` and ``class_names`` and appends
        to ``training_log``. Prints status to stdout. If the file is
        missing, nothing is loaded and a message is printed instead.
    """
    global models, class_names, training_log
    if not os.path.exists(model_filename):
        print(f"Model file {model_filename} not found. Please train the model first.")
        return
    print("Found existing model file. Loading...")
    # NOTE: pickle.load is only safe on trusted, locally produced files.
    with open(model_filename, "rb") as f:
        payload = pickle.load(f)
    models = payload['models']   # Expecting a dict: {'svm': ..., 'rf': ..., 'combined': ...}
    class_names = payload['class_names']
    training_log += "Loaded model from disk.\n"
    print("Loaded models from disk.")

# ---------------------------------------------------------------------
# Gradio Classification Function with Model Selection
# ---------------------------------------------------------------------
def classify_new_image(input_image_path, model_choice):
    """
    Classify an image by splitting it into patches and aggregating patch votes.

    Expects input_image_path as a file path and model_choice as one of the keys in models.
    Loads the image, processes it by extracting patches, classifies each patch, and aggregates
    patch predictions. Also, draws transparent overlays on each patch according to its predicted label.

    The selected classifier must expose predict_proba (sklearn-style), and
    extract_features_from_image must return a dict with a 'combined_features' array.

    Returns:
        final_prediction (str): The final predicted class.
        prob_dict (dict): Mapping of each class name to its normalized averaged probability.
        annotated_image_rgb (numpy array): The image with transparent overlays, in RGB order.

    Raises:
        ValueError: If model_choice is not a loaded model key, or the image
            cannot be read from input_image_path.
    """
    global models, training_log, class_names
    progress_log = training_log + "\nStarting classification...\n"

    # Fail fast on an unknown model key rather than a KeyError later.
    if model_choice not in models:
        raise ValueError(f"Model choice '{model_choice}' not found. Available choices: {list(models.keys())}")
    classifier = models[model_choice]

    # Load image using OpenCV from file path (returns None on failure).
    image = cv2.imread(input_image_path)
    if image is None:
        raise ValueError("Error: Could not load image from the provided file path.")

    # Resize the image to a fixed width (1000 px) while maintaining aspect ratio,
    # so the 100x100 patch grid covers a comparable area for any input size.
    fixed_width = 1000
    height, width = image.shape[:2]
    aspect_ratio = height / width
    new_height = int(fixed_width * aspect_ratio)
    resized_image = cv2.resize(image, (fixed_width, new_height))
    progress_log += "Resized image to fixed width of 1000 pixels.\n"

    # The image from cv2.imread is already in BGR format.
    image_bgr = resized_image
    progress_log += "Image loaded in BGR format.\n"

    # Preprocessing – Convert to grayscale, apply Gaussian blur, and compute edges.
    # NOTE(review): threshold1=0 makes Canny hysteresis keep every weak edge
    # connected to a strong one — presumably intentional to maximize edge
    # density for the patch filter; confirm.
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (9, 9), 0)
    edges = cv2.Canny(blurred, threshold1=0, threshold2=100)
    progress_log += "Computed edges using Canny edge detection.\n"

    # Patch extraction parameters (non-overlapping 100x100 tiles).
    patch_size = (100, 100)
    patch_w, patch_h = patch_size
    img_h, img_w = gray.shape

    valid_patch_count = 0
    summed_probabilities = None  # Running sum of per-patch probability vectors
    overlays_list = []  # To store (x, y, w, h, predicted_label) for each valid patch

    # Loop over non-overlapping patches; partial tiles at the right/bottom
    # edges are skipped by the range bounds.
    for y in range(0, img_h - patch_h + 1, patch_h):
        for x in range(0, img_w - patch_w + 1, patch_w):
            patch_edges = edges[y:y+patch_h, x:x+patch_w]
            patch = resized_image[y:y+patch_h, x:x+patch_w]
            num_edge_pixels = np.sum(patch_edges > 0)
            total_pixels = patch_w * patch_h
            density = num_edge_pixels / total_pixels
            progress_log += f"Patch at ({x}, {y}) - edge density: {density:.3f}\n"

            # Keep only "textured" patches: some edges, but not noise-saturated.
            if 0.0 < density < 0.5:
                valid_patch_count += 1
                features = extract_features_from_image(patch)
                feature_vector = features['combined_features'].reshape(1, -1)
                patch_probabilities = classifier.predict_proba(feature_vector)[0]
                predicted_index = np.argmax(patch_probabilities)
                predicted_label = class_names[predicted_index]
                progress_log += f"Patch at ({x}, {y}) predicted: {predicted_label} with probabilities {patch_probabilities}\n"

                overlays_list.append((x, y, patch_w, patch_h, predicted_label))

                if summed_probabilities is None:
                    summed_probabilities = patch_probabilities
                else:
                    summed_probabilities += patch_probabilities

    # Fallback: if no valid patches are found, classify the whole image.
    if valid_patch_count == 0:
        progress_log += "No valid patches found. Falling back to whole image classification.\n"
        features = extract_features_from_image(image_bgr)
        feature_vector = features['combined_features'].reshape(1, -1)
        summed_probabilities = classifier.predict_proba(feature_vector)[0]
        valid_patch_count = 1

    # Average the probabilities from all valid patches and normalize them
    # so the reported probabilities sum to 1.
    averaged_probabilities = summed_probabilities / valid_patch_count
    normalized_probabilities = averaged_probabilities / np.sum(averaged_probabilities)

    final_prediction_index = np.argmax(normalized_probabilities)
    final_prediction = class_names[final_prediction_index]
    prob_dict = {cls: float(normalized_probabilities[i]) for i, cls in enumerate(class_names)}
    progress_log += "Classification completed.\n"

    # Logs go to the terminal only; they are not returned to the UI.
    print(progress_log)
    print(prob_dict)

    # Create an annotated image with transparent overlays.
    annotated_image = resized_image.copy()
    overlay = annotated_image.copy()
    alpha = 0.4  # Transparency factor

    # Define overlay colors in BGR for each class (adjust as desired)
    color_map = {
        'wood': (0, 255, 255),   # Yellow (BGR format)
        'brick': (0, 0, 255),    # Red (BGR format)
        'stone': (128, 128, 128) # Gray (BGR format)
}

    # NOTE(review): this threshold uses the image-level averaged probability
    # of the patch's class (prob_dict), not the individual patch confidence —
    # so overlays for a class are drawn all-or-nothing; confirm intended.
    for (x, y, w, h, label) in overlays_list:
        label_lower = label.lower()
        if prob_dict[label] < 0.2:
            continue
        color = color_map.get(label_lower, (0, 255, 0))  # Default to green if unknown
        cv2.rectangle(overlay, (x, y), (x+w, y+h), color, thickness=-1)

    # Blend the overlay with the original image
    annotated_image = cv2.addWeighted(overlay, alpha, annotated_image, 1 - alpha, 0)
    # Convert annotated image from BGR to RGB for Gradio display
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)

    return final_prediction, prob_dict, annotated_image_rgb

# ---------------------------------------------------------------------
# Gradio Interface Setup using file paths and model selection
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Load the pre-trained classifiers before building the UI.
    saved_model_path = "./svm_rf_combined.pkl"  # Adjust filename as needed
    load_model(saved_model_path)

    # Offer whichever model keys were actually loaded; fall back to the
    # known defaults if nothing could be read from disk.
    available_models = list(models.keys()) if models else ['svm', 'rf', 'combined']

    # Wire the classifier into a simple Gradio form: image + model picker in,
    # predicted label + probability breakdown + annotated image out.
    demo = gr.Interface(
        fn=classify_new_image,
        inputs=[
            gr.Image(type="filepath", label="Input Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=available_models[0]),
        ],
        outputs=[
            gr.Label(label="Predicted Class"),
            gr.Label(label="Probabilities"),
            gr.Image(label="Annotated Image"),
        ],
        title="Stone, Wood, Brick Classifier",
        description=("Upload an image and select a classifier model (svm, rf, combined) to classify it.\n\n"
                     "The image is processed by subdividing it into patches and aggregating the predictions. "
                     "Transparent overlays are drawn on detected objects. Progress logs are printed to the terminal.")
    )
    demo.launch(share=True)